{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/ai-platform-samples/blob/master/notebooks/samples/pytorch/iris_classification/serving_pytorch_models_in_ai_platform.ipynb\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"> Run in Colab\n",
    "    </a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/notebooks/samples/pytorch/iris_classification/serving_pytorch_models_in_ai_platform.ipynb\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\">\n",
    "      View on GitHub\n",
    "    </a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Yy8l-95JkRx0"
   },
   "source": [
    "# Overview\n",
    "\n",
    "AI Platform Online Prediction now supports custom python code in to apply custom prediction routines, including custom (stateful) pre/post processing, and/or models not created by the standard supported frameworks (TensorFlow, Keras, Scikit-learn, XGBoost).\n",
    "\n",
    "### Dataset\n",
    "\n",
    "We use the [Iris dataset](https://archive.ics.uci.edu/ml/datasets/Iris)\n",
    "\n",
    "### Objective\n",
    "\n",
    "In this notebook, we show how to deploy a model created by [PyTorch](https://pytorch.org/) using AI Platform  Custom Prediction Code using Iris dataset for a multi-class classification problem.\n",
    "\n",
    "### Costs \n",
    "\n",
    "This tutorial uses billable components of Google Cloud Platform (GCP):\n",
    "\n",
    "* Cloud AI Platform\n",
    "* Cloud Storage\n",
    "\n",
    "Learn about [Cloud AI Platform Prediction](https://cloud.google.com/ai-platform/prediction/pricing)\n",
    "pricing and [Cloud Storage\n",
    "pricing](https://cloud.google.com/storage/pricing), and use the [Pricing\n",
    "Calculator](https://cloud.google.com/products/calculator/)\n",
    "to generate a cost estimate based on your projected usage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "D5AirRX0kRx2"
   },
   "source": [
    "### Set up your local development environment\n",
    "\n",
    "**If you are using Colab or AI Platform Notebooks**, your environment already meets\n",
    "all the requirements to run this notebook. You can skip this step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Otherwise**, make sure your environment meets this notebook's requirements.\n",
    "You need the following:\n",
    "\n",
    "* The Google Cloud SDK\n",
    "* Git\n",
    "* Python 3\n",
    "* virtualenv\n",
    "* Jupyter notebook running in a virtual environment with Python 3\n",
    "\n",
    "The Google Cloud guide to [Setting up a Python development\n",
    "environment](https://cloud.google.com/python/setup) and the [Jupyter\n",
    "installation guide](https://jupyter.org/install) provide detailed instructions\n",
    "for meeting these requirements. The following steps provide a condensed set of\n",
    "instructions:\n",
    "\n",
    "1. [Install and initialize the Cloud SDK.](https://cloud.google.com/sdk/docs/)\n",
    "\n",
    "2. [Install Python 3.](https://cloud.google.com/python/setup#installing_python)\n",
    "\n",
    "3. [Install\n",
    "   virtualenv](https://cloud.google.com/python/setup#installing_and_using_virtualenv)\n",
    "   and create a virtual environment that uses Python 3.\n",
    "\n",
    "4. Activate that environment and run `pip install jupyter` in a shell to install\n",
    "   Jupyter.\n",
    "\n",
    "5. Run `jupyter notebook` in a shell to launch Jupyter.\n",
    "\n",
    "6. Open this notebook in the Jupyter Notebook Dashboard."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up your GCP project\n",
    "\n",
    "**The following steps are required, regardless of your notebook environment.**\n",
    "\n",
    "1. [Select or create a GCP project.](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.\n",
    "\n",
    "2. [Make sure that billing is enabled for your project.](https://cloud.google.com/billing/docs/how-to/modify-project)\n",
    "\n",
    "3. [Enable the AI Platform APIs and Compute Engine APIs.](https://console.cloud.google.com/flows/enableapi?apiid=ml.googleapis.com,compute_component)\n",
    "\n",
    "4. Enter your project ID in the cell below. Then run the  cell to make sure the\n",
    "Cloud SDK uses the right project for all the commands in this notebook.\n",
    "\n",
    "**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Authenticate your GCP account\n",
    "\n",
    "**If you are using AI Platform Notebooks**, your environment is already\n",
    "authenticated. Skip this step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**If you are using Colab**, run the cell below and follow the instructions\n",
    "when prompted to authenticate your account via oAuth.\n",
    "\n",
    "**Otherwise**, follow these steps:\n",
    "\n",
    "1. In the GCP Console, go to the [**Create service account key**\n",
    "   page](https://console.cloud.google.com/apis/credentials/serviceaccountkey).\n",
    "\n",
    "2. From the **Service account** drop-down list, select **New service account**.\n",
    "\n",
    "3. In the **Service account name** field, enter a name.\n",
    "\n",
    "4. From the **Role** drop-down list, select\n",
    "   **Machine Learning Engine > AI Platform Admin** and\n",
    "   **Storage > Storage Object Admin**.\n",
    "\n",
    "5. Click *Create*. A JSON file that contains your key downloads to your\n",
    "local environment.\n",
    "\n",
    "6. Enter the path to your service account key as the\n",
    "`GOOGLE_APPLICATION_CREDENTIALS` variable in the cell below and run the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "no_execute"
    ]
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# If you are running this notebook in Colab, run this cell and follow the\n",
    "# instructions to authenticate your GCP account. This provides access to your\n",
    "# Cloud Storage bucket and lets you submit training jobs and prediction\n",
    "# requests.\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "  from google.colab import auth as google_auth\n",
    "  google_auth.authenticate_user()\n",
    "\n",
    "# If you are running this notebook locally, replace the string below with the\n",
    "# path to your service account key and run this cell to authenticate your GCP\n",
    "# account.\n",
    "else:\n",
    "  %env GOOGLE_APPLICATION_CREDENTIALS ''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## PIP Install Packages and dependencies\n",
    "\n",
    "Before we start let's install pytorch and gcloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 5726,
     "status": "ok",
     "timestamp": 1538245666639,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "b_4YHbyCkRx2",
    "outputId": "7caa60de-641d-4cb4-aab1-38ba0e9db975"
   },
   "outputs": [],
   "source": [
    "! pip install torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5pHlrb0slo3i"
   },
   "source": [
    "If you are running this notebook in Colab, run the following cell to authenticate your Google Cloud Platform user account"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 5530,
     "status": "ok",
     "timestamp": 1538245790469,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "lQq3X_e_kRx9",
    "outputId": "eb142794-f9c9-4baf-e584-4629c7a90c72"
   },
   "outputs": [],
   "source": [
    "PROJECT_ID = '[your-project-id]' # TODO (Set to your GCP Project name)\n",
    "!gcloud config set project {PROJECT_ID}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BUCKET_NAME = '[your-bucket-name]' # TODO (Set to your GCS Bucket name)\n",
    "REGION = 'us-central1' #@param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "zgIaQaRakRyC"
   },
   "source": [
    "## 3. Download iris data\n",
    "In this example, we want to build a classifier for the simple [iris dataset](https://archive.ics.uci.edu/ml/datasets/iris). So first, we download the data csv file locally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5zu_lz-RkRyD"
   },
   "outputs": [],
   "source": [
    "!mkdir data\n",
    "!mkdir models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "LOCAL_DATA_DIR = \"data/iris.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 985,
     "status": "ok",
     "timestamp": 1538245803249,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "CNz7zOYYkRyG",
    "outputId": "fe3856f7-3b9f-46fa-eb5f-a86d122b3a8c"
   },
   "outputs": [],
   "source": [
    "from urllib.request import urlretrieve\n",
    "\n",
    "urlretrieve(\"https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data\", LOCAL_DATA_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fYg0djCukRyK"
   },
   "source": [
    "# Part 1: Build a PyTorch NN Classifier\n",
    "\n",
    "Make sure that pytorch package is [installed](https://pytorch.org/get-started/locally/)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 7065,
     "status": "ok",
     "timestamp": 1538245839403,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "oB3betHBkRyL",
    "outputId": "28de856b-1663-47c1-a22c-cebafdd13a36"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "\n",
    "print('PyTorch Version: {}'.format(torch.__version__))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qr0w3nDvkRyO"
   },
   "source": [
    "## 1. Load Data \n",
    "In this step, we are going to:\n",
    "1. Load the data to Pandas Dataframe.\n",
    "2. Convert the class feature (species) from string to a numeric indicator.\n",
    "3. Split the Dataframe into input feature (xtrain) and target feature (ytrain)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 509,
     "status": "ok",
     "timestamp": 1538245841842,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "lk810wj8kRyP",
    "outputId": "a6788660-f35a-4324-b34f-60e5c28efc7b"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "CLASS_VOCAB = ['setosa', 'versicolor', 'virginica']\n",
    "\n",
    "datatrain = pd.read_csv(LOCAL_DATA_DIR, names=['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species'])\n",
    "\n",
    "#change string value to numeric\n",
    "datatrain.loc[datatrain['species']=='Iris-setosa', 'species']=0\n",
    "datatrain.loc[datatrain['species']=='Iris-versicolor', 'species']=1\n",
    "datatrain.loc[datatrain['species']=='Iris-virginica', 'species']=2\n",
    "datatrain = datatrain.apply(pd.to_numeric)\n",
    "\n",
    "#change dataframe to array\n",
    "datatrain_array = datatrain.to_numpy()\n",
    "\n",
    "#split x and y (feature and target)\n",
    "xtrain = datatrain_array[:,:4]\n",
    "ytrain = datatrain_array[:,4]\n",
    "\n",
    "input_features = xtrain.shape[1]\n",
    "num_classes = len(CLASS_VOCAB)\n",
    "\n",
    "print('Records loaded: {}'.format(len(xtrain)))\n",
    "print('Number of input features: {}'.format(input_features))\n",
    "print('Number of classes: {}'.format(num_classes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eCYbi-h7kRyU"
   },
   "source": [
    "## 2. Set model parameters\n",
    "You can try different values for **hidden_units** or **learning_rate**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qvyvyolgkRyU"
   },
   "outputs": [],
   "source": [
    "HIDDEN_UNITS = 10\n",
    "LEARNING_RATE = 0.1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JzAVqhvCkRyX"
   },
   "source": [
    "## 3. Define the PyTorch NN model\n",
    "\n",
    "Here, we build a a neural network with one hidden layer, and a Softmax output layer for classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4BP-U3gPkRyX"
   },
   "outputs": [],
   "source": [
    "model = torch.nn.Sequential(\n",
    "    torch.nn.Linear(input_features, HIDDEN_UNITS),\n",
    "    torch.nn.Sigmoid(),\n",
    "    torch.nn.Linear(HIDDEN_UNITS, num_classes),\n",
    "    torch.nn.Softmax()\n",
    ")\n",
    "\n",
    "loss_metric = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(),lr=LEARNING_RATE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bM79nn7HkRyZ"
   },
   "source": [
    "## 4. Train the model\n",
    "We are going to train the model for **num_epoch** epochs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 258
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 6461,
     "status": "ok",
     "timestamp": 1538245853213,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "ElikXuyokRya",
    "outputId": "c6226007-47fb-4ea7-faee-fdce8db0049a"
   },
   "outputs": [],
   "source": [
    "NUM_EPOCHS = 10000\n",
    "\n",
    "for epoch in range(NUM_EPOCHS):\n",
    "    \n",
    "    x = Variable(torch.Tensor(xtrain).float())\n",
    "    y = Variable(torch.Tensor(ytrain).long())\n",
    "    optimizer.zero_grad()\n",
    "    y_pred = model(x)\n",
    "    loss = loss_metric(y_pred, y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    if (epoch) % 1000 == 0:\n",
    "        print('Epoch [{}/{}] Loss: {}'.format(epoch+1, NUM_EPOCHS, round(loss.item(),3)))\n",
    "        \n",
    "print('Epoch [{}/{}] Loss: {}'.format(epoch+1, NUM_EPOCHS, round(loss.item(),3)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pzE7_e_kkRyc"
   },
   "source": [
    "## 5. Save and load the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TVGH-GS9kRye"
   },
   "outputs": [],
   "source": [
    "LOCAL_MODEL_DIR = \"models/model.pt\"\n",
    "\n",
    "\n",
    "torch.save(model, LOCAL_MODEL_DIR)\n",
    "iris_classifier = torch.load(LOCAL_MODEL_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nbb4S2s5kRyh"
   },
   "source": [
    "## 6. Test the loaded model for predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3E_i7qpVkRyi"
   },
   "outputs": [],
   "source": [
    "def predict_class(instances):\n",
    "    instances = torch.Tensor(instances)\n",
    "    output = iris_classifier(instances)\n",
    "    _ , predicted = torch.max(output, 1)\n",
    "    return predicted"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7Jdl1MlVkRyn"
   },
   "source": [
    "Get predictions for the first 5 instances in the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 638,
     "status": "ok",
     "timestamp": 1538245859222,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "HEyiYiRckRyn",
    "outputId": "66e52ccc-dac9-47e5-e47c-83572c202557"
   },
   "outputs": [],
   "source": [
    "predicted = predict_class(xtrain[0:5])\n",
    "print([CLASS_VOCAB[class_index] for class_index in predicted])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get the classification accuracy on the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "accuracy = round(sum(np.array(predict_class(xtrain)) == ytrain)/float(len(ytrain))*100,2)\n",
    "print('Classification accuracy: {} %'.format(accuracy))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vC4N_80wkRyr"
   },
   "source": [
    "## 7. Upload trained model to Cloud Storage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8205,
     "status": "ok",
     "timestamp": 1538245870922,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "Ue5XadCLkRys",
    "outputId": "27d285d6-3efd-44bb-ae7d-b1910a46dfbe"
   },
   "outputs": [],
   "source": [
    "GCS_MODEL_DIR='models/pytorch/iris_classifier/'\n",
    "\n",
    "!gsutil -m cp -r {LOCAL_MODEL_DIR} gs://{BUCKET_NAME}/{GCS_MODEL_DIR}\n",
    "!gsutil ls gs://{BUCKET_NAME}/{GCS_MODEL_DIR}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aQP0lGikkRyw"
   },
   "source": [
    "# Part 2: Prepare the Custom Prediction Package\n",
    "\n",
    "1. Implement a model **custom class** for pre/post processing, as well as loading and using your model for prediction.\n",
    "2. Prepare yout **setup.py** file, to include all the modules and packages you need in your custome model class."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KQDLnysXkRyw"
   },
   "source": [
    "## 1. Create the custom model class\n",
    "In the **from_path**, you load the pytorch model that you uploaded to GCS. Then in the **predict** method, you use it for prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 625,
     "status": "ok",
     "timestamp": 1538245874513,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "fg6f8pP8kRyx",
    "outputId": "1a9feaf0-e515-4f3a-fe13-f96d4a1310a8"
   },
   "outputs": [],
   "source": [
    "%%writefile model.py\n",
    "\n",
    "import os\n",
    "import pandas as pd\n",
    "from google.cloud import storage\n",
    "import torch\n",
    "\n",
    "class PyTorchIrisClassifier(object):\n",
    "    \n",
    "    def __init__(self, model):\n",
    "        self._model = model\n",
    "        self.class_vocab = ['setosa', 'versicolor', 'virginica']\n",
    "        \n",
    "    @classmethod\n",
    "    def from_path(cls, model_dir):\n",
    "        model_file = os.path.join(model_dir,'model.pt')\n",
    "        model = torch.load(model_file)    \n",
    "        return cls(model)\n",
    "\n",
    "    def predict(self, instances, **kwargs):\n",
    "        data = pd.DataFrame(instances).as_matrix()\n",
    "        inputs = torch.Tensor(data)\n",
    "        outputs = self._model(inputs)\n",
    "        _ , predicted = torch.max(outputs, 1)\n",
    "        return [self.class_vocab[class_index] for class_index in predicted]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Y22oukOdkRy1"
   },
   "source": [
    "## 2. Create a setup.py module\n",
    "\n",
    "Create a setup.py script to bundle **model.py** in a tarball package. Notice that setup.py does not include the dependencies of model.py in the package. These dependencies are provided to your model version in other ways:\n",
    "\n",
    "`pandas` and `google-cloud-storage` are both included as part of AI Platform Prediction runtime version 1.15.\n",
    "\n",
    "`torch` is provided in a separate package, as described in a following section.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 560,
     "status": "ok",
     "timestamp": 1538245877594,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "QjOxQt2lkRy2",
    "outputId": "3d06578d-d47b-4f30-d927-9225adb4e762"
   },
   "outputs": [],
   "source": [
    "%%writefile setup.py\n",
    "\n",
    "from setuptools import setup\n",
    "\n",
    "REQUIRED_PACKAGES = []\n",
    "\n",
    "setup(\n",
    "    name=\"iris-custom-model\",\n",
    "    version=\"0.1\",\n",
    "    scripts=[\"model.py\"],\n",
    "    install_requires=REQUIRED_PACKAGES\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "AMopNrL0kRy6"
   },
   "source": [
    "## 3. Create the package \n",
    "\n",
    "This will create a .tar.gz package under /dist directory. The name of the package will be (name)-(version).tar.gz where (name) and (version) are the ones specified in the setup.py."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 544
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3243,
     "status": "ok",
     "timestamp": 1538245884136,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "MhDjxWNKkRy7",
    "outputId": "209288b7-cece-4a06-8595-4764921691e0"
   },
   "outputs": [],
   "source": [
    "!python setup.py sdist"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Y0HtInoxkRy_"
   },
   "source": [
    "## 4. Uploaded the package to GCS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 7247,
     "status": "ok",
     "timestamp": 1538245946830,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "1rY8YgOEkRzA",
    "outputId": "bc6db003-2d0c-4470-eaf9-78e51e23109d"
   },
   "outputs": [],
   "source": [
    "GCS_PACKAGE_URI='models/pytorch/packages/iris-custom-model-0.1.tar.gz'\n",
    "\n",
    "!gsutil cp ./dist/iris-custom-model-0.1.tar.gz gs://{BUCKET_NAME}/{GCS_PACKAGE_URI}\n",
    "!gsutil ls gs://{BUCKET_NAME}/{GCS_PACKAGE_URI}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GhF3pYQTkRzE"
   },
   "source": [
    "# Part 3: Deploy the Model to AI Platform for Online Predictions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-40ZkeakkRzF"
   },
   "source": [
    "## 1. Create AI Platform model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 8718,
     "status": "ok",
     "timestamp": 1538245958274,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "muG2T57skRzG",
    "outputId": "d6747346-1f2b-4382-a8ca-8e9535a6c5be"
   },
   "outputs": [],
   "source": [
    "MODEL_NAME='torch_iris_classifier'\n",
    "MODEL_VERSION='v1'\n",
    "RUNTIME_VERSION='1.15'\n",
    "MODEL_CLASS='model.PyTorchIrisClassifier'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete model version resource\n",
    "! gcloud ai-platform versions delete {MODEL_VERSION} --model {MODEL_NAME} --quiet\n",
    "\n",
    "# Delete model resource\n",
    "! gcloud ai-platform models delete {MODEL_NAME} --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# You can uncomment to enable logging\n",
    "\n",
    "!gcloud ai-platform models create {MODEL_NAME} --regions {REGION} #--enable-logging --enable-console-logging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bNbqAHgFkRzI"
   },
   "source": [
    "## 2. Create AI Platform model version\n",
    "\n",
    "Once you have your custom package ready, you can specify this as an argument when creating a version resource. Note that you need to provide the path to your package (as package-uris) and also the class name that contains your custom predict method (as model-class)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Pytorch compatible packages"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to specify two Python packages when you create your version resource. One of these is the package containing `model.py` that you uploaded to Cloud Storage in a previous step. The other is a package containing the version of PyTorch that you need.\n",
    "\n",
    "Google Cloud provides a collection of PyTorch packages in the `gs://cloud-ai-pytorch` Cloud Storage bucket. These packages are mirrored from the official builds.\n",
    "\n",
    "For this tutorial, use `gs://cloud-ai-pytorch/torch-1.3.1+cpu-cp37-cp37m-linux_x86_64.whl` as your PyTorch package. This provides your version resource with PyTorch 1.3.1 for Python 3.7, built to run on a CPU in Linux.\n",
    "\n",
    "Use the following command to create your version resource:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "gpXx-nFLkRzJ"
   },
   "outputs": [],
   "source": [
    "! gcloud beta ai-platform versions create {MODEL_VERSION} --model={MODEL_NAME} \\\n",
    "            --origin=gs://{BUCKET_NAME}/{GCS_MODEL_DIR} \\\n",
    "            --python-version=3.7 \\\n",
    "            --runtime-version={RUNTIME_VERSION} \\\n",
    "            --machine-type=mls1-c4-m4 \\\n",
    "            --package-uris=gs://{BUCKET_NAME}/{GCS_PACKAGE_URI},gs://cloud-ai-pytorch/torch-1.3.1+cpu-cp37-cp37m-linux_x86_64.whl \\\n",
    "            --prediction-class={MODEL_CLASS}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 3411,
     "status": "ok",
     "timestamp": 1538246126033,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "naG7FtQokRzM",
    "outputId": "2df0240e-1445-47e8-9a41-627dfacab273"
   },
   "outputs": [],
   "source": [
    "! gcloud ai-platform versions list --model {MODEL_NAME}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "TQz-JaqpkRzP"
   },
   "source": [
    "# Part 4: AI Platform Online Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Wdx5nTxjkRzR"
   },
   "outputs": [],
   "source": [
    "from googleapiclient import discovery\n",
    "from oauth2client.client import GoogleCredentials\n",
    "\n",
    "credentials = GoogleCredentials.get_application_default()\n",
    "api = discovery.build('ml', 'v1', credentials=credentials)\n",
    "\n",
    "\n",
    "def estimate(project, model_name, version, instances):\n",
    "    request_data = {'instances': instances}\n",
    "    model_url = 'projects/{}/models/{}/versions/{}'.format(project, model_name, version)\n",
    "    response = api.projects().predict(body=request_data, name=model_url).execute()\n",
    "    predictions = response[\"predictions\"]\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 869,
     "status": "ok",
     "timestamp": 1538246147364,
     "user": {
      "displayName": "Khalid Salama",
      "photoUrl": "",
      "userId": "16897751573918541504"
     },
     "user_tz": -60
    },
    "id": "ETKlF4yHkRzT",
    "outputId": "f1089ffd-c6c0-4db9-cf64-02bf414a5701"
   },
   "outputs": [],
   "source": [
    "instances = [\n",
    "    [6.8, 2.8, 4.8, 1.4],\n",
    "    [6. , 3.4, 4.5, 1.6]\n",
    "]\n",
    "\n",
    "predictions = estimate(instances=instances\n",
    "                     ,project=PROJECT_ID\n",
    "                     ,model_name=MODEL_NAME\n",
    "                     ,version=MODEL_VERSION)\n",
    "\n",
    "print(predictions)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleaning up\n",
    "\n",
    "To clean up all GCP resources used in this project, you can [delete the GCP\n",
    "project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.\n",
    "\n",
    "Alternatively, you can clean up individual resources by running the following\n",
    "commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete model version resource\n",
    "! gcloud ai-platform versions delete {MODEL_VERSION} --model {MODEL_NAME} --quiet\n",
    "\n",
    "# Delete model resource\n",
    "! gcloud ai-platform models delete {MODEL_NAME} --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nk8oZ2zPkRzW"
   },
   "source": [
    "# Questions? Feedback?\n",
    "Feel free to send us an email (cloudml-feedback@google.com) if you run into any issues or have any questions/feedback!"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Serving PyTorch Models with AI Platform  Custom Prediction Code.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
